{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "df840597-64ce-4834-852e-48ced451f69f",
      "metadata": {
        "id": "ac9da08f7821"
      },
      "source": [
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/getting_started.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "00e00b3b-d7ed-4e2e-815e-3addfc23c8f3",
      "metadata": {
        "id": "23dd0a0eba89"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 The AI Edge Torch Authors.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a9bdc007e6ce"
      },
      "source": [
        "Note: When running notebooks in this repository with Google Colab, some users may see\n",
        "the following warning message:\n",
        "\n",
        "![Colab warning](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/data/colab_warning.jpg?raw=true)\n",
        "\n",
        "Please click `Restart Session` and run again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a91d40b5-91f0-4c19-bdb4-a2f56fa1c5ff",
      "metadata": {
        "id": "2b09cc13a5c1"
      },
      "outputs": [],
      "source": [
        "!pip install -r https://raw.githubusercontent.com/google-ai-edge/ai-edge-torch/main/requirements.txt\n",
        "!pip install ai-edge-torch-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f744e7c3-e360-4f3a-8d75-74759265b4aa",
      "metadata": {
        "id": "2027d669fce6"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import ai_edge_torch\n",
        "import torch\n",
        "import torchvision"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cec203fc-7b6d-41bf-9716-4b76af45b019",
      "metadata": {
        "id": "2bf24a1bd28e"
      },
      "source": [
        "# Sample PyTorch Model\n",
        "\n",
        "Instantiate `resnet18` as a sample model from PyTorch's `torchvision` package. We also provide it with a sample input and execute it directly via PyTorch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c4ad2105-ce72-4f00-8f4d-74b0d505d422",
      "metadata": {
        "id": "c96810c259a9"
      },
      "outputs": [],
      "source": [
        "resnet18 = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()\n",
        "sample_inputs = (torch.randn(1, 3, 224, 224),)\n",
        "torch_output = resnet18(*sample_inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "efbc9364-e0ce-4213-a0a7-07b0b6a264ae",
      "metadata": {
        "id": "ba2ad90ae477"
      },
      "source": [
        "# Conversion\n",
        "The `convert` function provided by the `ai_edge_torch` package allows conversion from a PyTorch model to an on-device model. The conversion process also requires a model's sample input for tracing and shape inference.\n",
        "\n",
        "**Note**: The source PyTorch model needs to be compliant with `torch.export` introduced in PyTorch 2.1.0 ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "508e4a18-dd21-410c-bf65-ecb062d4d3ba",
      "metadata": {
        "id": "26a208b29579"
      },
      "outputs": [],
      "source": [
        "edge_model = ai_edge_torch.convert(resnet18, sample_inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "35ee138e-f93d-4b47-a698-27f985fb2d3a",
      "metadata": {
        "id": "f61e660adb9f"
      },
      "source": [
        "# Inference\n",
        "Get outputs from inference with the TFLite runtime by directly calling the edge_model with the inputs. Many of the details of [TFLite inference in Python](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python) are abstracted away with this API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0b9860f6-0dc4-41ca-ac31-6a177e89f4c3",
      "metadata": {
        "id": "d53042b5e46a"
      },
      "outputs": [],
      "source": [
        "edge_output = edge_model(*sample_inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b9b6c6ca-1ebf-4011-b5cd-86a5be666f1c",
      "metadata": {
        "id": "7862f0d68600"
      },
      "source": [
        "# Validation\n",
        "Here, we make sure that the output generated by the on-device prepared model created by `ai_edge_torch` matches the output generated by PyTorch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f2cd200f-aa9e-4eb2-94cf-9d683a01ded8",
      "metadata": {
        "id": "ea6b6914879e"
      },
      "outputs": [],
      "source": [
        "if np.allclose(torch_output.detach().numpy(), edge_output, atol=1e-5):\n",
        "    print(\"Inference result with Pytorch and TfLite was within tolerance\")\n",
        "else:\n",
        "    print(\"Something wrong with Pytorch --> TfLite\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5ee2c9f3-585a-43ff-a9ef-82e0e7b58dc3",
      "metadata": {
        "id": "83468e71907a"
      },
      "source": [
        "# Serialization\n",
        "The on-device prepared model also provides an `export` interface which can be used to serialize the model. This serializes the model as a TFLite Flatbuffers file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4df580b7-ff62-4d24-b99c-ae0c43b47a88",
      "metadata": {
        "id": "942812454807"
      },
      "outputs": [],
      "source": [
        "edge_model.export('resnet.tflite')\n",
        "\n",
        "# Download the tflite flatbuffer which can be used with the existing TfLite APIs.\n",
        "# from google.colab import files\n",
        "# files.download('resnet.tflite')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "92d06de3-2a33-4d9c-bdc0-8128379f1d6d",
      "metadata": {
        "id": "52027ca7613f"
      },
      "source": [
        "# Visualization\n",
        "The TFLite flatbuffer can be visualized using the AI Edge Model Explorer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6d04ede7-5e72-41ba-b4a8-4debe7e12507",
      "metadata": {
        "id": "1c5cc28c58de"
      },
      "outputs": [],
      "source": [
        "!pip install ai-edge-model-explorer\n",
        "\n",
        "import model_explorer\n",
        "model_explorer.visualize('resnet.tflite')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "getting_started.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
